import os
import soundfile as sf
import numpy as np
# Compute the signal-to-noise ratio (SNR) of a signal against its clean reference.
def SNR_singlech(clean_wav, noisy_wav, aaa=None):
    """Return the SNR in dB of ``noisy_wav`` relative to ``clean_wav``.

    The noise is estimated as ``noisy_wav - clean_wav`` and the SNR is
    ``10 * log10(sum(clean ** 2) / sum(noise ** 2))``, summed over every
    sample.  Both arrays must have the same shape; multi-channel arrays
    (e.g. ``(frames, channels)``) are supported.

    A warning is printed when either signal contains a sample whose
    magnitude exceeds 1 (outside the normalized float audio range).

    Args:
        clean_wav: Reference (clean) waveform.
        noisy_wav: Waveform under test.
        aaa: Unused; kept (now optional) for backward compatibility with
            callers that pass it (``cal_mean_SNR`` passes the clean file path).

    Returns:
        The SNR in decibels.  ``inf`` when the estimated noise is exactly
        zero and the clean signal is not.
    """
    # Work in float64 so squaring integer PCM samples cannot overflow.
    clean = np.asarray(clean_wav, dtype=np.float64)
    noisy = np.asarray(noisy_wav, dtype=np.float64)
    est_noise = noisy - clean
    # np.max(np.abs(...)) works for any array rank and also catches samples
    # below -1; the builtin max() failed on 2-D arrays and missed negatives.
    if clean.size and np.max(np.abs(clean)) > 1:
        print("原始音频剪切问题---")
    if noisy.size and np.max(np.abs(noisy)) > 1:
        print("对抗样本生产产生问题")
    signal_power = np.sum(clean ** 2)
    noise_power = np.sum(est_noise ** 2)
    # Identical signals give zero noise power; report +inf without a warning.
    with np.errstate(divide="ignore"):
        SNR = 10 * np.log10(signal_power / noise_power)
    return SNR

def cal_mean_SNR(attack_dir, advsae_dir):
    """Print and return the mean SNR (dB) of adversarial samples vs. originals.

    Every regular file in ``advsae_dir`` is paired with the file of the same
    name in ``attack_dir``.  Each adversarial waveform is clipped to
    [-1, 1] before its SNR against the original is computed with
    ``SNR_singlech``.

    Args:
        attack_dir: Directory containing the original (clean) audio files.
        advsae_dir: Directory containing the adversarial audio files.

    Returns:
        The mean SNR over all file pairs (also printed).

    Raises:
        ValueError: If ``advsae_dir`` contains no regular files.
    """
    # Sort for a deterministic processing order; skip sub-directories, which
    # sf.read cannot open.
    alladvname = [
        name
        for name in sorted(os.listdir(advsae_dir))
        if os.path.isfile(os.path.join(advsae_dir, name))
    ]
    if not alladvname:
        # Previously this surfaced as a ZeroDivisionError in the mean.
        raise ValueError(f"no adversarial sample files found in {advsae_dir!r}")

    snr_values = []
    for item in alladvname:
        prewav, _ = sf.read(os.path.join(advsae_dir, item))
        # Only the adversarial sample is clipped to [-1, 1]; the clean
        # reference is compared as read.
        prewav = np.clip(prewav, a_min=-1, a_max=1)
        clean_path = os.path.join(attack_dir, item)
        cleanwav, _ = sf.read(clean_path)
        snr_values.append(
            SNR_singlech(clean_wav=cleanwav, noisy_wav=prewav, aaa=clean_path)
        )

    mean_snr = sum(snr_values) / len(snr_values)
    print("平均信噪比 为:", str(mean_snr))
    return mean_snr



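# Disabled debugging snippet: plots the clean waveform, the adversarial waveform
# and their difference. It references Config, librosa and plt, none of which
# are imported in this file.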
# clean_file = Config.success_original_path
# noisy_file = Config.MFCC_LCNN_FGSMAdvsetPath
# for i in range(len(os.listdir(clean_file))):
#     for root, dirs, files in os.walk(clean_file):
#         clean_wav, sr = sf.read(os.path.join(clean_file, files[0]))
#         noisy_wav, sr = sf.read(os.path.join(noisy_file, files[0]))
#         librosa.display.waveshow(clean_wav, sr)
#         plt.show()
#         # SNR_singlech(clean_wav, noisy_wav)
# for root, dirs, files in os.walk(clean_file):
#     plt.figure()
#     plt.subplot(2, 2, 1)
#     clean_wav, sr = sf.read(os.path.join(clean_file, files[0]))
#     noisy_wav, sr = sf.read(os.path.join(noisy_file, files[0]))
#     librosa.display.waveshow(clean_wav, sr)
#     plt.subplot(2, 2, 2)
#     librosa.display.waveshow(noisy_wav, sr)
#     plt.subplot(2,2, 3)
#     librosa.display.waveshow(noisy_wav-clean_wav,sr)
#     plt.xlim(0, 4)
#     plt.ylim(-1, 1)
#     plt.show()